{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/home/husein/t5/prepare/mesolitica-tpu.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import itertools\n",
    "from glob import glob\n",
    "from multiprocessing import Pool\n",
    "from google.cloud import storage\n",
    "import tensorflow as tf\n",
    "import json\n",
    "import regex as re\n",
    "from functools import lru_cache\n",
    "import tensorflow as tf\n",
    "import gpt_2_simple\n",
    "from tqdm import tqdm\n",
    "import collections\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "@lru_cache()\n",
    "def bytes_to_unicode():\n",
    "    \"\"\"\n",
    "    Returns list of utf-8 byte and a corresponding list of unicode strings.\n",
    "    The reversible bpe codes work on unicode strings.\n",
    "    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n",
    "    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n",
    "    This is a signficant percentage of your normal, say, 32K bpe vocab.\n",
    "    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n",
    "    And avoids mapping to whitespace/control characters the bpe code barfs on.\n",
    "    \"\"\"\n",
    "    bs = (\n",
    "        list(range(ord('!'), ord('~') + 1))\n",
    "        + list(range(ord('¡'), ord('¬') + 1))\n",
    "        + list(range(ord('®'), ord('ÿ') + 1))\n",
    "    )\n",
    "    cs = bs[:]\n",
    "    n = 0\n",
    "    for b in range(2 ** 8):\n",
    "        if b not in bs:\n",
    "            bs.append(b)\n",
    "            cs.append(2 ** 8 + n)\n",
    "            n += 1\n",
    "    cs = [chr(n) for n in cs]\n",
    "    return dict(zip(bs, cs))\n",
    "\n",
    "\n",
    "def get_pairs(word):\n",
    "    \"\"\"Return set of symbol pairs in a word.\n",
    "    Word is represented as tuple of symbols (symbols being variable-length strings).\n",
    "    \"\"\"\n",
    "    pairs = set()\n",
    "    prev_char = word[0]\n",
    "    for char in word[1:]:\n",
    "        pairs.add((prev_char, char))\n",
    "        prev_char = char\n",
    "    return pairs\n",
    "\n",
    "\n",
    "class Encoder:\n",
    "    def __init__(self, encoder, bpe_merges, errors='replace'):\n",
    "        self.encoder = encoder\n",
    "        self.decoder = {v: k for k, v in self.encoder.items()}\n",
    "        self.errors = errors  # how to handle errors in decoding\n",
    "        self.byte_encoder = bytes_to_unicode()\n",
    "        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n",
    "        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))\n",
    "        self.cache = {}\n",
    "\n",
    "        # Should haved added re.IGNORECASE so BPE merges can happen for\n",
    "        # capitalized versions of contractions\n",
    "        self.pat = re.compile(\n",
    "            r\"\"\"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+\"\"\"\n",
    "        )\n",
    "\n",
    "    def bpe(self, token):\n",
    "        if token in self.cache:\n",
    "            return self.cache[token]\n",
    "        word = tuple(token)\n",
    "        pairs = get_pairs(word)\n",
    "\n",
    "        if not pairs:\n",
    "            return token\n",
    "\n",
    "        while True:\n",
    "            bigram = min(\n",
    "                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))\n",
    "            )\n",
    "            if bigram not in self.bpe_ranks:\n",
    "                break\n",
    "            first, second = bigram\n",
    "            new_word = []\n",
    "            i = 0\n",
    "            while i < len(word):\n",
    "                try:\n",
    "                    j = word.index(first, i)\n",
    "                    new_word.extend(word[i:j])\n",
    "                    i = j\n",
    "                except BaseException:\n",
    "                    new_word.extend(word[i:])\n",
    "                    break\n",
    "\n",
    "                if (\n",
    "                    word[i] == first\n",
    "                    and i < len(word) - 1\n",
    "                    and word[i + 1] == second\n",
    "                ):\n",
    "                    new_word.append(first + second)\n",
    "                    i += 2\n",
    "                else:\n",
    "                    new_word.append(word[i])\n",
    "                    i += 1\n",
    "            new_word = tuple(new_word)\n",
    "            word = new_word\n",
    "            if len(word) == 1:\n",
    "                break\n",
    "            else:\n",
    "                pairs = get_pairs(word)\n",
    "        word = ' '.join(word)\n",
    "        self.cache[token] = word\n",
    "        return word\n",
    "\n",
    "    def encode(self, text):\n",
    "        bpe_tokens = []\n",
    "        for token in re.findall(self.pat, text):\n",
    "            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\n",
    "            bpe_tokens.extend(\n",
    "                self.encoder[bpe_token]\n",
    "                for bpe_token in self.bpe(token).split(' ')\n",
    "            )\n",
    "        return bpe_tokens\n",
    "\n",
    "    def decode(self, tokens):\n",
    "        text = ''.join([self.decoder[token] for token in tokens])\n",
    "        text = bytearray([self.byte_decoder[c] for c in text]).decode(\n",
    "            'utf-8', errors=self.errors\n",
    "        )\n",
    "        return text\n",
    "\n",
    "\n",
    "def create_int_feature(values):\n",
    "    feature = tf.train.Feature(\n",
    "        int64_list=tf.train.Int64List(value=list(values))\n",
    "    )\n",
    "    return feature\n",
    "\n",
    "\n",
    "def write_tfrecord(s, file):\n",
    "    r = tf.python_io.TFRecordWriter(file)\n",
    "    for i in tqdm(range(len(s))):\n",
    "        features = collections.OrderedDict()\n",
    "        features['input_ids'] = create_int_feature(s[i])\n",
    "        tf_example = tf.train.Example(\n",
    "            features=tf.train.Features(feature=features)\n",
    "        )\n",
    "        r.write(tf_example.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('encoder.json', 'r') as f:\n",
    "    en = json.load(f)\n",
    "with open('vocab.bpe', 'r', encoding=\"utf-8\") as f:\n",
    "    bpe_data = f.read()\n",
    "\n",
    "bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\\n')[1:-1]]\n",
    "enc_malay = Encoder(\n",
    "    encoder=en,\n",
    "    bpe_merges=bpe_merges,\n",
    ")\n",
    "\n",
    "length = 1024\n",
    "combine = 50000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://raw.githubusercontent.com/huseinzol05/malay-dataset/master/dumping/karangan-sekolah/karangan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "token_chunks = []\n",
    "raw_text = ''\n",
    "with open('karangan', 'r', encoding='utf8', errors='ignore') as fp:\n",
    "    raw_text += fp.read()\n",
    "\n",
    "if len(raw_text) >= combine:\n",
    "    tokens = enc_malay.encode(raw_text)\n",
    "    token_chunks.append(tokens)\n",
    "    raw_text = ''\n",
    "else:\n",
    "    raw_text += '<|endoftext|>'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "s = []\n",
    "for l in range(len(token_chunks)):\n",
    "    for i in range(0, len(token_chunks[l]), length):\n",
    "        index = min(i + length, len(token_chunks[l]))\n",
    "        x = token_chunks[l][i: index]\n",
    "        if len(x) != length:\n",
    "            x = token_chunks[l][index - length: index]\n",
    "        s.append(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "82"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 82/82 [00:00<00:00, 2615.16it/s]\n"
     ]
    }
   ],
   "source": [
    "output_file = 'karangan.tfrecord'\n",
    "write_tfrecord(s, output_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "client = storage.Client()\n",
    "bucket = client.bucket('mesolitica-tpu-general')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "blob = bucket.blob(f'gpt2-testset/{output_file}')\n",
    "blob.upload_from_filename(output_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm {output_file}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
